Skip to content

Unroll _transform_tuple to fix Enzyme.autodiff on tuples of length ≥ 33#170

Open
jlperla wants to merge 1 commit into
tpapp:masterfrom
jlperla:enzyme-3104-unroll-transform-tuple
Open

Unroll _transform_tuple to fix Enzyme.autodiff on tuples of length ≥ 33#170
jlperla wants to merge 1 commit into
tpapp:masterfrom
jlperla:enzyme-3104-unroll-transform-tuple

Conversation

@jlperla

@jlperla jlperla commented May 15, 2026

Copy link
Copy Markdown

Replace the Base.tail-recursive _transform_tuple with a @generated straight-line unroll — same outputs bit-for-bit, but the typed IR no longer contains a self-invoke, which is what Enzyme.autodiff (Forward and Reverse) trips on at tuple length ≥ 33 with AssertionError("conv == 37") (EnzymeAD/Enzyme.jl#3104).

The recursive Base.tail fold in _transform_tuple makes Enzyme.autodiff
(Forward and Reverse) throw `AssertionError("conv == 37")` from
Enzyme/src/rules/jitrules.jl:2073 once the tuple has ≥ 33 entries
(EnzymeAD/Enzyme.jl#3104). Replace it with a @generated straight-line
unroll that produces the same outputs bit-for-bit while emitting no
self-invoke in the typed IR — which is what Enzyme trips on.

Verified against the full Pkg.test() suite (all Pass = Total) and a
35-entry SW07-Pfeifer-style NamedTuple prior (fwd + rev both succeed).
@scheidan

scheidan commented Jun 15, 2026

Copy link
Copy Markdown
Contributor

Any chance that this will be merged here? I understand that the real fix should be on Enzyme's side, but that may be much harder.

Thanks!

PS: The is my real world MWE that lead me finally to this PR; maybe it is useful for someone.

using Distributions
using Enzyme
using TransformVariables

N = 33
dists = ntuple(i -> LogNormal(0.0, 1.0), N)
dists = NamedTuple{ntuple(i -> Symbol("x", i), N)}(dists)

function prior_transform(priors)
    transforms = map(priors) do prior
        left, right = extrema(support(prior))
        left = isinf(left) ? -TransformVariables.∞ : left
        right = isinf(right) ? TransformVariables.∞ : right
        TransformVariables.as(Real, left, right)
    end
    TransformVariables.as(transforms)
end

trans = prior_transform(dists)
q = fill(-0.1, TransformVariables.dimension(trans))

foo(q) = sum(values(TransformVariables.transform(trans, q)))

Enzyme.gradient(Enzyme.Reverse, foo, q) # AssertionError: conv == 37

@tpapp

tpapp commented Jun 15, 2026

Copy link
Copy Markdown
Owner

@jlperla, thanks for this, @scheidan, thanks for the ping. I apologize for the delay in reviewing this.

It is not strictly equivalent as, AFAIK, built-ins do not necessarily unroll above a certain tuple length. But given that the intention of using a tuple is to get type-stable code, I don't see a problem with this here. Also, EnzymeAD/Enzyme.jl#3104 indicates that this is an issue on the Julia side, so fixing it on our end may be the best option for now.

@devmotion, this is fine with me, do you have any comments?

@tpapp

tpapp commented Jun 15, 2026

Copy link
Copy Markdown
Owner

(closing and reopening to make CI run)

@tpapp tpapp closed this Jun 15, 2026
@tpapp tpapp reopened this Jun 15, 2026
Comment thread src/aggregation.jl
Comment on lines +396 to +399
Implemented as a `@generated` straight-line unroll over the static tuple length.
Equivalent to the natural `Base.tail` recursion, but emits non-recursive code
so that `Enzyme.autodiff` does not hit `AssertionError("conv == 37")` on
tuples of length ≥ 33 (EnzymeAD/Enzyme.jl#3104).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems very internal for being part of a docstrings? It also might change again in case of upstream compiler or Enzyme changes.

Suggested change
Implemented as a `@generated` straight-line unroll over the static tuple length.
Equivalent to the natural `Base.tail` recursion, but emits non-recursive code
so that `Enzyme.autodiff` does not hit `AssertionError("conv == 37")` on
tuples of length 33 (EnzymeAD/Enzyme.jl#3104).

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an internal helper function anyway and not part of the API. (I like to document my internal functions too, I know this is not common to do so). As far as I am concerned this is fine.

Comment thread src/aggregation.jl
for i in 1:N]
ℓ_sum = foldl((a, b) -> :($a + $b), ℓs)
return quote
idx = index

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is a separate idx variable needed? Couldn't we just operate with index?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants